"""Common CKI logging functions."""
import contextlib
import datetime
import json
import logging
import logging.handlers
import os
import re
import sys
import threading
import time
import typing

# Log file directory path. NOTE(review): not referenced in this part of the
# file; presumably used by callers elsewhere — confirm.
CKI_LOGS_FILEDIR = '/tmp/cki-lib.log/'
# Destination of the stream handler that get_logger installs when
# CKI_LOGGING_FILE is unset. Bound once, at import time.
STREAM = sys.stderr
# Record layout used by the plain-text formatter.
FORMAT = '%(asctime)s - [%(levelname)s] - %(name)s - %(message)s'
# Guards reads/writes of CKI_HANDLER in get_logger.
LOCK = threading.Lock()
# Whether get_logger has already claimed the one-time root handler setup.
CKI_HANDLER = False

# Per-thread store of extra fields set via logging_env(); the JSON and pack
# formatters include its contents in every record.
LOGGING_ENV = threading.local()


@contextlib.contextmanager
def logging_env(variables: typing.Dict[str, typing.Any]) -> typing.Iterator[None]:
    """Set logging environment variables for the duration of the block.

    The variables are stored on the thread-local LOGGING_ENV, whose contents
    the JSON and pack formatters add to every record. On exit, each variable
    is restored to the value it had on entry, or removed if it was unset. This
    lets nested contexts that set the same key unwind correctly, instead of
    the outer exit failing with KeyError.

    Args:
        variables: Mapping of variable name to value to expose while active.
    """
    unset = object()
    # Remember prior values so nested/overlapping contexts can be unwound.
    previous = {key: LOGGING_ENV.__dict__.get(key, unset) for key in variables}
    LOGGING_ENV.__dict__.update(variables)
    try:
        yield
    finally:
        for key, value in previous.items():
            if value is unset:
                LOGGING_ENV.__dict__.pop(key, None)
            else:
                LOGGING_ENV.__dict__[key] = value


class JSONFormatter(logging.Formatter):
    """Emit each log record as a single JSON object.

    Besides the logger name, level, timestamp and formatted message, the
    current thread's logging environment is attached under ``extras``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Serialize the record and the active logging environment to JSON."""
        rendered_message = super().format(record)
        logger_info = {'name': record.name, 'level': record.levelname}
        payload = {
            'timestamp': record.created,
            'logger': logger_info,
            'message': rendered_message,
            'extras': LOGGING_ENV.__dict__,
        }
        return json.dumps(payload, default=self.default)

    @staticmethod
    def default(obj: typing.Any) -> str:
        """Convert an object json cannot encode natively into a string.

        datetimes become ISO 8601 strings. Anything else goes through str();
        if that raises, a placeholder naming the object's class is returned.
        """
        if not isinstance(obj, datetime.datetime):
            try:
                return str(obj)
            except Exception:  # pylint: disable=broad-except
                return f'invalid: {obj.__class__.__name__}'
        return obj.isoformat()


class PackFormatter(logging.Formatter):
    """JSON logging formatter for Loki '| unpack'.

    Produces a flat JSON object whose values are all strings. Nested
    containers from the logging environment are flattened into
    underscore-joined keys below an ``extras`` prefix.
    """

    # Characters not allowed in a flattened key; each is replaced by '_'.
    _INVALID_KEY_CHARS = re.compile('[^a-zA-Z0-9_]')

    def format(self, record: logging.LogRecord) -> str:
        """Format record into a flat JSON object."""
        content = {
            'timestamp': str(record.created),
            'logger_name': record.name,
            'logger_level': record.levelname,
            '_entry': super().format(record),
        }
        self.flatten(content, 'extras', LOGGING_ENV.__dict__)
        return json.dumps(content)

    @classmethod
    def flatten(
        cls,
        output: typing.Dict[str, str],
        prefix: str,
        obj: typing.Any,
    ) -> None:
        """Flatten obj into string values stored in output under prefix.

        Dicts and lists are walked recursively, appending each key or index
        to prefix with '_'. datetime values become ISO 8601 strings; anything
        else is converted with str(). Characters outside [a-zA-Z0-9_] in the
        key are replaced by '_'. If conversion raises, the entry is recorded
        as 'invalid: <class name>' instead of propagating the error.

        Args:
            output: Mapping that receives the flattened key/value pairs.
            prefix: Key (or key prefix) for obj.
            obj: Arbitrary value to flatten.
        """
        prefix = cls._INVALID_KEY_CHARS.sub('_', prefix)
        try:
            if isinstance(obj, dict):
                for key, value in obj.items():
                    cls.flatten(output, f'{prefix}_{key}', value)
            elif isinstance(obj, list):
                for index, value in enumerate(obj):
                    cls.flatten(output, f'{prefix}_{index}', value)
            elif isinstance(obj, datetime.datetime):
                output[prefix] = obj.isoformat()
            else:
                output[prefix] = str(obj)
        except Exception:  # pylint: disable=broad-except
            # Best effort: logging must not fail because of an odd extra.
            output[prefix] = f'invalid: {obj.__class__.__name__}'


def _add_cki_handler() -> None:
    """Attach the CKI handler to the root logger and set the 'cki' log level."""
    root_logger = logging.getLogger()

    # CKI_LOGGING_FILE selects a file destination; otherwise log to STREAM.
    if file_name := os.environ.get('CKI_LOGGING_FILE'):
        handler: logging.Handler = logging.handlers.WatchedFileHandler(file_name)
    else:
        handler = logging.StreamHandler(STREAM)

    logging_format = os.environ.get('CKI_LOGGING_FORMAT', 'plain')
    if logging_format == 'json':
        formatter: logging.Formatter = JSONFormatter()
    elif logging_format == 'pack':
        formatter = PackFormatter()
    else:
        formatter = TextFormatter(fmt=FORMAT)
    # Attach the formatter before the handler becomes visible to loggers.
    handler.setFormatter(formatter)
    root_logger.addHandler(handler)

    # Set cki loglevel to CKI_LOGGING_LEVEL.
    cki_logger = logging.getLogger('cki')
    cki_logger.setLevel(os.environ.get('CKI_LOGGING_LEVEL', 'WARNING').upper())


def get_logger(logger_name: str) -> logging.Logger:
    """Return CKI logger or descendant.

    On the first call, a handler is added to the root logger:
    - a WatchedFileHandler on CKI_LOGGING_FILE if that is set, otherwise a
      StreamHandler on STREAM;
    - the formatter is chosen by CKI_LOGGING_FORMAT ('json', 'pack', or plain
      text);
    - the 'cki' logger level comes from CKI_LOGGING_LEVEL (default WARNING).

    Names outside the 'cki' hierarchy are prefixed with 'cki.'.

    Args:
        logger_name: Requested logger name.

    Returns:
        The logger named 'cki' or 'cki.<...>'.
    """
    global CKI_HANDLER  # pylint: disable=global-statement
    # get_logger might not be called from main, so the one-time setup runs
    # while holding LOCK. Concurrent first callers then cannot receive a
    # logger before the handler and level are in place.
    # https://docs.python.org/3/library/logging.html#logging.basicConfig
    with LOCK:
        if not CKI_HANDLER:
            # Claim the setup before running it, so a failing setup is not
            # retried on later calls (same once-only behavior as before).
            CKI_HANDLER = True
            _add_cki_handler()

    if not (logger_name == 'cki' or logger_name.startswith('cki.')):
        logger_name = 'cki.' + logger_name

    return logging.getLogger(logger_name)


class TextFormatter(logging.Formatter):
    """Plain-text formatter that logs UTC timestamps with microsecond precision."""

    def formatTime(self, record: logging.LogRecord, datefmt: typing.Optional[str] = None) -> str:
        """Return the creation time of the record as text, in UTC.

        The format is ``YYYY-mm-ddTHH:MM:SS.ffffff``. The fractional part is
        derived from ``record.created`` rather than ``record.msecs``: newer
        Python versions store ``msecs`` as whole milliseconds, which would
        zero the last three digits. ``datefmt`` is accepted for signature
        compatibility with ``logging.Formatter`` but ignored.
        """
        seconds, fraction = divmod(record.created, 1)
        # Clamp so float rounding can never produce a 7-digit fraction.
        microseconds = min(int(fraction * 1_000_000), 999_999)
        time_record = time.gmtime(seconds)
        return f'{time.strftime("%Y-%m-%dT%H:%M:%S", time_record)}.{microseconds:06d}'
